{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3",
   "language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# 数据预处理\n",
    "\n",
    "在进行训练之前对数据集进行简单的预处理，顺便查看一下数据的基本格式"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入包\n",
    "import time\n",
    "import numpy as np\n",
    "import wfdb\n",
    "import ast\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.pylab import mpl\n",
    "from scipy.fftpack import fft, ifft \n",
    "from scipy import signal"
   ]
  },
  {
   "source": [
    "## 读取文件\n",
    "\n",
    "基本的代码树设置如下所示：\n",
    "\n",
    "```\n",
    "PTB-XL-CLASSFIER\n",
    "├── code\n",
    "├── ├── dataPreprocess.ipynb\n",
    "├── data\n",
    "    ├── ptbxl_database.csv\n",
    "    ├── scp_statements.csv\n",
    "    ├── records100\n",
    "    │   ├── 00000\n",
    "    │   │   ├── 00001_lr.dat\n",
    "    │   │   ├── 00001_lr.hea\n",
    "    │   │   ├── ...\n",
    "    │   │   ├── 00999_lr.dat\n",
    "    │   │   └── 00999_lr.hea\n",
    "    │   ├── ...\n",
    "    │   └── 21000\n",
    "    │        ├── 21001_lr.dat\n",
    "    │        ├── 21001_lr.hea\n",
    "    │        ├── ...\n",
    "    │        ├── 21837_lr.dat\n",
    "    │        └── 21837_lr.hea\n",
    "    └── records500\n",
    "    ├── 00000\n",
    "    │     ├── 00001_hr.dat\n",
    "    │     ├── 00001_hr.hea\n",
    "    │     ├── ...\n",
    "    │     ├── 00999_hr.dat\n",
    "    │     └── 00999_hr.hea\n",
    "    ├── ...\n",
    "    └── 21000\n",
    "            ├── 21001_hr.dat\n",
    "            ├── 21001_hr.hea\n",
    "            ├── ...\n",
    "            ├── 21837_hr.dat\n",
    "            └── 21837_hr.hea\n",
    "```\n"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 设置读取文件路径\n",
    "path = '../data/'\n",
    "sampling_rate = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取文件并转换标签\n",
    "Y = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id')\n",
    "Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_raw_data(df, sampling_rate, path):\n",
    "    if sampling_rate == 100:\n",
    "        data = [wfdb.rdsamp(path+f) for f in df.filename_lr]\n",
    "    else:\n",
    "        data = [wfdb.rdsamp(path+f) for f in df.filename_hr]\n",
    "    data = np.array([signal for signal, meta in data])\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取原始信号数据\n",
    "X = load_raw_data(Y, sampling_rate, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取scp_statements.csv中的诊断信息\n",
    "agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agg_df = agg_df[agg_df.diagnostic == 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def aggregate_diagnostic(y_dic):\n",
    "    tmp = []\n",
    "    for key in y_dic.keys():\n",
    "        if key in agg_df.index:\n",
    "            tmp.append(agg_df.loc[key].diagnostic_class)\n",
    "    return list(set(tmp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 添加诊断信息\n",
    "Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y.columns"
   ]
  },
  {
   "source": [
    "## 绘图查看原始数据"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = (20.0, 10.0) \n",
    "plt.figure()\n",
    "plt.plot(X[0][:,0], linewidth=1.2)\n",
    "plt.grid(linestyle='--')\n",
    "# plt.yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ecg = X[4]\n",
    "titles = [\"I\", \"II\", \"III\", \"aVR\", \"aVL\", \"aVF\", \"V1\", \"V2\", \"V3\", \"V4\", \"V5\", \"V6\"]\n",
    "plt.rcParams['figure.figsize'] = (20.0, 20.0)\n",
    "plt.rcParams[\"axes.grid\"] = True\n",
    "plt.rcParams[\"grid.linestyle\"] = (0.1,0.1)\n",
    "plt.figure()\n",
    "for index in range(12):\n",
    "    plt.subplot(6,2,index+1)\n",
    "    plt.plot(ecg[:,index], linewidth=1)\n",
    "    \n",
    "    # plt.yticks(np.arange(np.min(ecg[:,index]), np.max(ecg[:,index]), 0.1))\n",
    "    plt.gca()\n",
    "    plt.title(titles[index])\n",
    "    plt.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## 心电滤波去基线漂移及分段"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "def np_move_avg(a,n,mode=\"same\"):\n",
    "    return(np.convolve(a, np.ones((n,))/n, mode=mode))"
   ],
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "source": [
    "### 五点平滑滤波"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ecg_original = X[40][:,0]\n",
    "ecg_filtered = np_move_avg(ecg_original, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fft变换查看频谱\n",
    "def ecg_fft_ana(ecg_original, sampling_rate):\n",
    "    fs = sampling_rate\n",
    "    ts = 1.0/fs\n",
    "    t = np.arange(0, 1, ts)\n",
    "    n = len(ecg_original)\n",
    "    k = np.arange(n)\n",
    "    t = n/fs\n",
    "    frq = k/t\n",
    "    frq = frq[range(int(n/2))]\n",
    "    fft_ecg = np.abs(fft(ecg_original))[range(int(n/2))]\n",
    "    return frq, fft_ecg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_spec_dif(ecg_original, ecg_filtered, title1=\"title1\", title2 = \"title2\"):\n",
    "    frq, fft_ecg = ecg_fft_ana(ecg_original, sampling_rate)\n",
    "    frq_filtered, fft_ecg_filtered = ecg_fft_ana(ecg_filtered, sampling_rate)\n",
    "    plt.figure()\n",
    "    plt.subplot(221)\n",
    "    plt.plot(ecg_original[:500])\n",
    "    plt.title(title1)\n",
    "    plt.subplot(222)\n",
    "    plt.plot(frq,fft_ecg)\n",
    "    plt.title(title1 + '`s spectrum')\n",
    "    plt.subplot(223)\n",
    "    plt.plot(ecg_filtered[:500])\n",
    "    plt.title(title2)\n",
    "    plt.subplot(224)\n",
    "    plt.plot(frq_filtered, fft_ecg_filtered)\n",
    "    plt.title(title2 + '`s spectrum')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_spec_dif(ecg_original, ecg_filtered, 'original ecg signal', 'filtered')"
   ]
  },
  {
   "source": [
    "### 陷波处理滤除工频干扰"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = sampling_rate  # 采样频率\n",
    "f0 = 50.0   # 要去除的工频干扰\n",
    "Q = 30.0    # 品质因数\n",
    "b, a = signal.iirnotch(f0, Q, fs)\n",
    "freq, h = signal.freqz(b, a, fs=fs)\n",
    "fig, ax = plt.subplots(2, 1, figsize=(8, 6))\n",
    "ax[0].plot(freq, 20*np.log10(abs(h)), color='blue')\n",
    "ax[0].set_title(\"Frequency Response\")\n",
    "ax[0].set_ylabel(\"Amplitude (dB)\", color='blue')\n",
    "ax[0].set_xlim([0, 100])\n",
    "ax[0].set_ylim([-25, 10])\n",
    "ax[0].grid()\n",
    "ax[1].plot(freq, np.unwrap(np.angle(h))*180/np.pi, color='green')\n",
    "ax[1].set_ylabel(\"Angle (degrees)\", color='green')\n",
    "ax[1].set_xlabel(\"Frequency (Hz)\")\n",
    "ax[1].set_xlim([0, 100])\n",
    "ax[1].set_yticks([-90, -60, -30, 0, 30, 60, 90])\n",
    "ax[1].set_ylim([-90, 90])\n",
    "ax[1].grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 工频干扰去除\n",
    "ecg_notch = signal.filtfilt(b, a, ecg_filtered)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_spec_dif(ecg_filtered, ecg_notch, \"ecg_unnotch\", \"ecg_notch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 去除所有导联噪声\n",
    "channels = 12\n",
    "for index in range(len(X)):\n",
    "    for channel in range(channels):\n",
    "        X[index][:, channel] = np_move_avg(X[index][:, channel], 5)"
   ]
  },
  {
   "source": [
    "## 基线漂移\n",
    "\n",
    "利用高通滤波器去除基线漂移，基线漂移是一种低频干扰，频率范围通常小于1Hz，大部分集中在0.1Hz，通常出现在ST段和Q波附近\n",
    "\n",
    "这里信号的采样频率为500Hz，要滤除基线漂移，通过高通滤波器得到基线漂移，然后用原信号减去基线漂移得到去基线漂移的心电信号\n",
    "\n",
    "滤波器的设计如下：\n",
    "\n",
    "截止频率为0.1Hz，$wn=2*0.1/1000=0.0002$"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# b, a = signal.butter(8, 0.01, 'highpass')\n",
    "# baseline = signal.filtfilt(b, a, ecg_filtered)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# diff = ecg_filtered-baseline\n",
    "# plt.figure()\n",
    "# plt.subplot(311)\n",
    "# plt.plot(ecg_filtered[:1000])\n",
    "# plt.subplot(312)\n",
    "# plt.plot(baseline[:1000])\n",
    "# plt.subplot(313)\n",
    "# plt.plot(diff[:1000])\n",
    "# plt.show()"
   ]
  },
  {
   "source": [
    ">参考的论文没有去除噪声，这里没有去除噪声，尝试去除基线漂移"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "### 心电分段提取\n",
    "\n",
    "根据心电12导联的$2$导联通道的R波分段，R波前的150个数据和R波之后的350个数据"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ecg = X[np.random.randint(len(X))][:,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_diff(ecg):\n",
    "    ecg_diff = np.zeros(len(ecg))\n",
    "    for i in range(len(ecg)-1):\n",
    "        ecg_diff[i] = ecg[i+1] - ecg[i]\n",
    "    ecg_diff[len(ecg)-1] = ecg[len(ecg)-1]\n",
    "    return ecg_diff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 检测R波函数\n",
    "def checkR(ecg):\n",
    "    max_val = np.max(ecg)\n",
    "    min_val = np.min(ecg)\n",
    "    threshold_val = (max_val-min_val)*0.7 + min_val\n",
    "    index = []\n",
    "    for i in range(1, len(ecg)-2):\n",
    "        # 满足差分阈值条件\n",
    "        if ecg[i] == np.max(ecg[i-1:i+2]) and ecg[i] > threshold_val:\n",
    "            # 满足心率间隔60-160\n",
    "            if index != []:\n",
    "                if i-index[-1] <= 60.0/60.0*sampling_rate and i-index[-1] >= 60.0/160.0*sampling_rate:\n",
    "                    index.append(i)\n",
    "            else:\n",
    "                index.append(i)\n",
    "    return np.array(index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = checkR(ecg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 检查R波检测是否正确\n",
    "plt.figure()\n",
    "plt.plot(ecg)\n",
    "for i in range(len(index)):\n",
    "    plt.scatter(index[i], ecg[index[i]],c='r')\n",
    "    plt.annotate('R',(index[i], ecg[index[i]]), fontsize=20)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ecg_diff = get_diff(ecg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.figure()\n",
    "# plt.subplot(211)\n",
    "# plt.plot(ecg)\n",
    "# plt.subplot(212)\n",
    "# plt.plot(ecg_diff)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 分段函数\n",
    "def splitByR(ecg):\n",
    "    index = checkR(ecg)              \n",
    "    ecg_rhythm = None\n",
    "    for i in range(len(index)):\n",
    "        # 提取出一段\n",
    "        if index[i]>200 and index[i]<1000-350:\n",
    "            ecg_rhythm = ecg[index[i]-150:index[i]+350]\n",
    "            continue\n",
    "    return ecg_rhythm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ecg_rhythm = splitByR(ecg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ecg_rhythm.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(ecg_rhythm)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对所有的信号的所有通道进行分段提取\n",
    "channels = 12\n",
    "test_size = len(X)\n",
    "ecg_rhythms = np.zeros([test_size, 500, 12])\n",
    "\n",
    "start_time = time.time()\n",
    "first_time = start_time\n",
    "for index in range(test_size):\n",
    "    if index%1000 == 0:\n",
    "        end_time = time.time()\n",
    "        print(\"finish %d in %d s\\n\" % (index, end_time - start_time))\n",
    "        start_time = time.time()\n",
    "    R_index = checkR(X[index][:,1])\n",
    "    for i in range(len(R_index)):\n",
    "        # 提取出一段\n",
    "        if R_index[i]>200 and R_index[i]<1000-350:\n",
    "            ecg_rhythms[index][:, :] = X[index][R_index[i]-150:R_index[i]+350,:]\n",
    "            continue\n",
    "\n",
    "end_time = time.time()\n",
    "\n",
    "print('time cost:%d s'%(end_time-first_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ecg_rhythms.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "for i in range(12):\n",
    "    plt.subplot(6, 2, i+1)\n",
    "    plt.plot(ecg_rhythms[40,:,i])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "false_count = 0\n",
    "Y['Runconfirmed'] = 0\n",
    "for index in range(len(ecg_rhythms)):\n",
    "    if ecg_rhythms[index].any() == np.zeros([500, 12]).any():\n",
    "        false_count += 1\n",
    "        Y['Runconfirmed'][index] = 1\n",
    "false_count"
   ]
  },
  {
   "source": [
    "## 划分数据集"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split data into train and test\n",
    "test_fold = 10\n",
    "# Train\n",
    "X_train = ecg_rhythms[(Y.strat_fold != test_fold)&(Y.Runconfirmed !=1)]\n",
    "y_train = Y[(Y.strat_fold != test_fold)&(Y.Runconfirmed !=1)].diagnostic_superclass\n",
    "# Test\n",
    "X_test = ecg_rhythms[(Y.strat_fold == test_fold)&(Y.Runconfirmed !=1)]\n",
    "y_test = Y[(Y.strat_fold == test_fold)&(Y.Runconfirmed !=1)].diagnostic_superclass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = '../data/numpy_data/'\n",
    "np.save(save_path+'X_train.npy', X_train)\n",
    "np.save(save_path+'y_train.npy', np.array(y_train))\n",
    "np.save(save_path+'X_test.npy', X_test)\n",
    "np.save(save_path+'y_test.npy', np.array(y_test))"
   ]
  }
 ]
}